from BasicLibraries import *
import MatrixPLT
from matplotlib import rcParams
from scipy.stats import median_test



# Global matplotlib styling applied to every figure drawn from this module.
# (``plt.rcParams`` and ``matplotlib.rcParams`` are the same object.)
rcParams.update({
    'xtick.major.pad': '8',
    'font.family': 'serif',
    # NOTE(review): 'font.family' is 'serif', so this sans-serif font list
    # looks unused for default text -- confirm whether Verdana was intended.
    'font.sans-serif': ['Verdana'],
})

# Four colours for the stacked sector bars drawn in ContributionParis.
color4 = ['indianred', 'navy', '#4FD284', 'orange']

# Shared matrix-plotting helper used by MakeEffectAnalysis.
MP = MatrixPLT.MatrixPlot()


class MakeEffectAnalysis():
    """Compare two populations ("bottom" and "top") and draw the related figures.

    Provides: behavioural difference matrices (via the module-level ``MP``
    helper), significance tests of the differences, Paris-Agreement alignment
    plots, excess-effort bar charts, and LaTeX cell markup for a table.
    Figures are saved under ``Figures/`` when requested.
    """

    def __init__(self):
        pass

    def BehavioralDifference(self, XBtm, XTp, environmental_sdgs, initiatives_list=None, percentage=False, dpi_=50, vmn=-4, vmx=2.5, ctr=0, show_plot=True, save_plot=False, plot_name=''):
        """Build the bottom/top matrices with ``MP.MakeMat`` and optionally plot their difference.

        Parameters
        ----------
        XBtm, XTp : inputs passed unchanged to ``MP.MakeMat`` for the bottom and top population.
        environmental_sdgs : forwarded to ``MakeMat`` as both ``sdgs_list`` and ``sdg``.
        initiatives_list : forwarded to ``MakeMat``; ``None`` (default) means an empty list.
        percentage : forwarded to ``MakeMat``.
        dpi_, vmn, vmx, ctr : forwarded to ``MP.SimpleMatrixPlot`` for the difference plot.
        show_plot : when True, plot ``XBottom - XTop``.
        save_plot : when True (and ``show_plot``), save the figure to
            ``Figures/Figure_bh_diff<plot_name>.pdf``.

        Returns
        -------
        (XBottom, XTop), with XBottom restricted/reordered to XTop's rows and columns.
        """
        # Fresh list per call instead of a shared mutable default argument.
        if initiatives_list is None:
            initiatives_list = []
        # The matrices are only built here, never shown, hence show_plot=False.
        make_kwargs = dict(sdgs_list=environmental_sdgs, initiatives_list=initiatives_list, percentage=percentage,
                           rounded=1, show_plot=False, sdg=environmental_sdgs, dpi_=50)
        XBottom = MP.MakeMat(XBtm, **make_kwargs)
        XTop = MP.MakeMat(XTp, **make_kwargs)
        # Align the bottom matrix on the top one so the subtraction is label-consistent.
        XBottom = XBottom.loc[XTop.index, XTop.columns]
        if show_plot:
            difference_matrix = XBottom - XTop
            MP.SimpleMatrixPlot(difference_matrix, dpi_=dpi_, vmn=vmn, vmx=vmx, ctr=ctr)
            if save_plot:
                plt.savefig('Figures/Figure_bh_diff'+plot_name+'.pdf', bbox_inches='tight', dpi=200)
        return XBottom, XTop

    @staticmethod
    def _rounded_p_values(XBtm, XTp, columns):
        """Return (t-test p-value, median-test p-value), both rounded to 2 decimals,
        comparing the per-row sums of ``columns`` between XBtm and XTp."""
        A = XBtm[columns].sum(axis=1)
        B = XTp[columns].sum(axis=1)
        # NOTE(review): ttest_ind is not among this file's visible imports; it presumably
        # comes from the BasicLibraries star import -- confirm it is scipy.stats.ttest_ind.
        return np.round(ttest_ind(A, B)[1], 2), np.round(median_test(A, B)[1], 2)

    def StatSignificance(self, initiatives, envSDG, XBtm, XTp):
        """Test whether bottom and top differ, per action and per SDG.

        Columns of XBtm/XTp are expected to be named ``'<initiative> - SDG <sdg>'``.
        For each initiative the columns over all ``envSDG`` are summed per row (and
        vice versa for each SDG), then the two samples are compared.

        Returns
        -------
        action_significance : DataFrame with columns ['action', 'p-value 1', 'p-value 2']
        sdg_significance : DataFrame with columns ['SDG', 'p-value 1', 'p-value 2']
        where 'p-value 1' is from ``ttest_ind`` and 'p-value 2' from ``median_test``.
        """
        res = [[k, *self._rounded_p_values(XBtm, XTp, [k+' - SDG '+str(S) for S in envSDG])]
               for k in initiatives]
        action_significance = pd.DataFrame(res, columns=['action', 'p-value 1', 'p-value 2'])
        resS = [[S, *self._rounded_p_values(XBtm, XTp, [I+' - SDG '+str(S) for I in initiatives])]
                for S in envSDG]
        sdg_significance = pd.DataFrame(resS, columns=['SDG', 'p-value 1', 'p-value 2'])
        return action_significance, sdg_significance

    #=======================================
    #=== Alignment with the Paris Agreement
    #=======================================

    @staticmethod
    def _direct_control_by_year_sector(frame):
        """Sum ``DirectControl`` per (rfyear, GICS_level_0); rows = rfyear, columns = sector."""
        grouped = frame[['GICS_level_0', 'rfyear', 'DirectControl']].groupby(['rfyear', 'GICS_level_0']).sum()
        # Keyword arguments: DataFrame.pivot is keyword-only since pandas 2.0.
        return grouped.reset_index().pivot(index='rfyear', columns='GICS_level_0', values='DirectControl')

    def subsetConstributionToSectorTotal(self, DT, sub_set):
        """Share of the full sample's ``DirectControl`` contributed by a subset of ``gvkey``s.

        For every year present in the subset, each sector's subset total is divided by
        the full-sample total summed over all sectors that appear in the subset.

        Parameters
        ----------
        DT : full sample with columns gvkey, rfyear, GICS_level_0, DirectControl.
        sub_set : frame whose ``gvkey`` values select the subset.

        Returns
        -------
        DataFrame indexed by rfyear (subset years only), one column per subset sector.
        """
        full = self._direct_control_by_year_sector(DT)
        subset = self._direct_control_by_year_sector(DT[DT.gvkey.isin(sub_set.gvkey.unique())])
        #== Total emissions (of the full sample) in the sectors represented in the subsample
        sector_total = full[subset.columns].sum(axis=1)
        # Reindex first: DataFrame.div outer-joins indexes and would otherwise add
        # all-NaN rows for years that exist only in the full sample.
        return subset.div(sector_total.reindex(subset.index), axis=0)

    def ContributionParis(self, DT, DTF, alig, misal, ax):
        """Plot, on ``ax``, the aligned vs misaligned percentage contribution to sector emissions.

        ``alig``/``misal`` select firms (by gvkey) from ``DT``; shares are computed
        against the full sample ``DTF``. Two stacked bar series are drawn side by side.
        """
        algn = DT[DT.gvkey.isin(alig.gvkey.unique())]
        misl = DT[DT.gvkey.isin(misal.gvkey.unique())]
        aln_contribution = self.subsetConstributionToSectorTotal(DTF, algn)*100
        msl_contribution = self.subsetConstributionToSectorTotal(DTF, misl)*100
        # Blank labels, presumably so the aligned bars get no legend entries of their
        # own. Sized to this frame's columns: the aligned subset may cover fewer
        # sectors than DT, which previously raised a length-mismatch ValueError.
        aln_contribution.columns = ['']*aln_contribution.shape[1]
        aln_contribution.plot(kind='bar', stacked=True, color=color4, ax=ax, alpha=0.65, rot=0, position=-0.05, width=0.2, legend=False)
        msl_contribution.plot(kind='bar', stacked=True, color=color4, ax=ax, alpha=0.65, rot=0, position=1.05, width=0.2)
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(handles[::-1], labels[::-1], loc='center', bbox_to_anchor=(0.5, 1.1), ncol=4, prop={'size': 22})
        ax.set_xlabel('')
        ax.set_ylabel('Contribution to \n sector emissions, %', fontsize=24)
        ax.margins(x=0)
        ax.set_ylim(0, 100)
        # Hand-placed "A." (aligned) / "M." (misaligned) markers.
        ax.text(0, 32, 'A.', color='deepskyblue', weight='bold')
        ax.text(-0.2, 65, 'M.', color='mediumseagreen', weight='bold')
        ax.get_xticklabels()[-1].set_color("red")

    def AlignmentPastEmissions(self, algn, msl, ax, VAR_='em_intensity'):
        """Box-plot ``VAR_`` per rfyear for the aligned vs misaligned populations on ``ax``.

        When ``VAR_ == 'DirectControl'`` values are divided by 1e6 and labelled MtCO2e
        (presumably the raw data is in tCO2e -- confirm).
        """
        X = algn[['rfyear', VAR_]].sort_values(by='rfyear').assign(Population='Aligned')
        Y = msl[['rfyear', VAR_]].sort_values(by='rfyear').assign(Population='Misaligned')
        Z = pd.concat((X, Y))
        my_pal = {"Misaligned": "mediumseagreen", "Aligned": "deepskyblue"}
        if VAR_ == 'DirectControl':
            Z[VAR_] /= 1000000
        sns.boxplot(x='rfyear', y=VAR_, hue='Population', data=Z,
                    palette=my_pal, showfliers=False, ax=ax,)
        if VAR_ == 'DirectControl':
            ax.set_ylabel(r'GHG Emissions, MtCO$_2$e', fontsize=24)
        elif VAR_ == 'em_intensity':
            ax.set_ylabel(r'Emission intensity', fontsize=24)

        ax.set_xlabel('', fontsize=28)
        ax.legend(loc='center', bbox_to_anchor=(0.5, 1.1), title='', ncol=2)
        # Highlight the last x tick label (presumably the most recent year).
        ax.get_xticklabels()[-1].set_color("red")

    def PARISStatistics(self, DT, DTF, alig, misal, algn, msl, show_plot=True, save_plot=False, FIGNAME=''):
        """Side-by-side box plots of emissions and emission intensity (aligned vs misaligned).

        Only ``algn``, ``msl`` and the plotting flags are used; ``DT``, ``DTF``,
        ``alig`` and ``misal`` are accepted but unused. Saves to
        ``Figures/em_stat_<FIGNAME>.pdf`` when ``save_plot`` is True.
        """
        if show_plot:
            plt.figure(figsize=(28, 6))
            ax = plt.subplot(121)
            self.AlignmentPastEmissions(algn, msl, ax, VAR_='DirectControl')
            ax = plt.subplot(122)
            self.AlignmentPastEmissions(algn, msl, ax, VAR_='em_intensity')
            if save_plot:
                plt.savefig('Figures/em_stat_'+FIGNAME+'.pdf', bbox_inches='tight', dpi=200)

    @staticmethod
    def _excess_effort_panel(ax, data, significance, highlight_significance):
        """Draw one horizontal excess-effort bar chart on ``ax``.

        When ``highlight_significance`` is set, y tick labels are drawn black if the
        matching 'p-value 1' is < 0.1 and gray otherwise (assumes tick order equals
        the row order of ``data`` and ``significance``).
        """
        sns.barplot(x='Excess effort, %', y='SDGs', data=data, palette='coolwarm', ax=ax)
        ax.axvline(x=0, ls='-.', c='k', lw=1.5)
        ax.set_ylabel('')
        ax.set_xlabel('Excess effort, %', fontsize=28)
        if highlight_significance:
            for tick_label, p_value in zip(ax.get_yticklabels(), significance['p-value 1']):
                tick_label.set_color("black" if p_value < 0.1 else "gray")
                tick_label.set_fontsize("30")

    def ExcessEffort(self, XBt, XTp, action_significance_level, sdg_significance_level, show_plot=True, save_plot=False, plot_name='PARIS', highlight_significance=False):
        """Rank initiatives and SDGs by the excess effort of bottom over top.

        The 'Total' row and column are dropped from ``XBt - XTp``; rows are summed per
        initiative and columns per SDG, each sorted in descending order. The
        significance tables are reordered to match; they must be indexed by the same
        labels (a KeyError is raised otherwise).

        Parameters
        ----------
        highlight_significance : colour y tick labels by 'p-value 1' (< 0.1 black,
            else gray). Defaults to False, matching the previous hard-coded behaviour.

        Returns
        -------
        (ini_, sdgs_) : DataFrames with columns ['SDGs', 'Excess effort, %'] (the
        initiative table also uses 'SDGs' as its label column).
        """
        diff_matrix = (XBt - XTp).drop(columns=['Total'], index=['Total'])
        sdg_totals = diff_matrix.sum(axis=0).sort_values(ascending=False)
        initiative_totals = diff_matrix.sum(axis=1).sort_values(ascending=False)
        # Use the sorted index directly; reset_index()['index'] fails when the index is named.
        sdg_significance_level = sdg_significance_level.loc[sdg_totals.index]
        action_significance_level = action_significance_level.loc[initiative_totals.index]

        sdgs_ = sdg_totals.reset_index()
        sdgs_.columns = ['SDGs', 'Excess effort, %']
        ini_ = initiative_totals.reset_index()
        ini_.columns = ['SDGs', 'Excess effort, %']
        if show_plot:
            plt.figure(figsize=(24, 8))
            self._excess_effort_panel(plt.subplot(121), ini_, action_significance_level, highlight_significance)
            self._excess_effort_panel(plt.subplot(122), sdgs_, sdg_significance_level, highlight_significance)
            plt.tight_layout()
            if save_plot:
                plt.savefig('Figures/FigureExcess'+plot_name+'.pdf', bbox_inches='tight', dpi=200)
        return ini_, sdgs_

    ### These are the functions to print the table in latex for Table 1
    def excess_effort_table_FIXED(self, X):
        """Map each cell's sign to a fixed-colour LaTeX ``\\mycircle``: +1 blue, -1 red, 0 white."""
        a = X.apply(np.sign)
        a = a.replace([1], [r'\mycircle{blue}'])
        a = a.replace([-1], [r'\mycircle{red}'])
        a = a.replace([0], [r'\mycircle{white}'])
        return a

    def excess_effort_table(self, X):
        """Return a copy of X with every cell replaced by a shaded LaTeX ``\\mycircle``.

        Positive values give ``\\mycircle{blue!<int(v*30)>}``; all other values give
        ``\\mycircle{red!<abs(int(v*30))>}`` (so 0 maps to red!0). NaN raises
        ValueError from ``int``. X itself is not modified.
        """
        def _cell(value):
            if value > 0:
                return r'\mycircle{blue!'+str(int(value*30))+'}'
            return r'\mycircle{red!'+str(abs(int(value*30)))+'}'

        # Build a new object frame instead of chained `A[j].iloc[i] = ...` assignment,
        # which does not write through under pandas copy-on-write.
        cells = [[_cell(value) for value in row] for row in X.itertuples(index=False, name=None)]
        return pd.DataFrame(cells, index=X.index, columns=X.columns)
        
        
        
        